import cv2
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import pickle

# Load the training image and its segmentation (label) image, both decoded
# as single-channel grayscale.
image_train = cv2.imread("Sandstone_1.tif", cv2.IMREAD_GRAYSCALE)
segmentation_train = cv2.imread("Sandstone_1_segment.tif", cv2.IMREAD_GRAYSCALE)

# cv2.imread reports a missing or unreadable file by returning None instead
# of raising, so fail here with a clear message rather than later with an
# obscure AttributeError.
if image_train is None:
    raise FileNotFoundError("Could not read training image: Sandstone_1.tif")
if segmentation_train is None:
    raise FileNotFoundError(
        "Could not read segmentation image: Sandstone_1_segment.tif"
    )

# Choose one feature-extraction algorithm:
# 'mean' (box filter), 'gaussian' (Gaussian blur), 'sobel' or 'canny'.
chosen_algorithm = 'gaussian'

# Apply the chosen algorithm to produce the single per-pixel feature.
if chosen_algorithm == 'mean':
    processed_image_train = cv2.blur(image_train, (5, 5))  # 5x5 mean filter
elif chosen_algorithm == 'gaussian':
    # sigma=0 lets OpenCV derive the standard deviation from the 5x5 kernel size.
    processed_image_train = cv2.GaussianBlur(image_train, (5, 5), 0)
elif chosen_algorithm == 'sobel':
    # Output is requested as 64-bit float (cv2.CV_64F).
    # NOTE(review): dx=1, dy=1 yields the mixed x/y derivative, not a gradient
    # magnitude -- confirm this is the intended edge feature.
    processed_image_train = cv2.Sobel(image_train, cv2.CV_64F, 1, 1, ksize=5)
elif chosen_algorithm == 'canny':
    # Canny edge detector with hysteresis thresholds 50 and 150.
    processed_image_train = cv2.Canny(image_train, 50, 150)
else:
    # Without this branch an unknown name would leave processed_image_train
    # undefined and surface later as a NameError.
    raise ValueError(
        "Unknown chosen_algorithm {!r}; expected one of "
        "'mean', 'gaussian', 'sobel', 'canny'".format(chosen_algorithm)
    )

# Every pixel is one sample, so the feature image and the label image must
# have identical shapes; otherwise the flattened arrays would not line up.
if processed_image_train.shape != segmentation_train.shape:
    raise ValueError(
        "Feature image shape {} does not match segmentation shape {}".format(
            processed_image_train.shape, segmentation_train.shape
        )
    )

# Flatten the processed image and the label image into 1-D arrays
# (one entry per pixel).
X_train = processed_image_train.flatten()
y_train = segmentation_train.flatten()

print("图像形状:", image_train.shape)
print("完成从砂岩截面图1及其对应分区中获取X和y")

# Split the pixels into a training set and a held-out test set (20% test).
# The fixed seed keeps the partition identical from run to run.
split_options = {"test_size": 0.2, "random_state": 42}
X_train, X_test, y_train, y_test = train_test_split(
    X_train, y_train, **split_options
)

print("完成train_test_split")

# Use a random forest as the per-pixel classifier.
# random_state is fixed (matching the seed used for the data split) so that
# training is reproducible; n_jobs=-1 builds the trees on all available
# cores, which matters because every pixel is a training sample.
clf = RandomForestClassifier(random_state=42, n_jobs=-1)

# scikit-learn expects a 2-D (n_samples, n_features) matrix; each pixel is a
# sample with a single feature, hence the reshape to (-1, 1).
clf.fit(X_train.reshape(-1, 1), y_train)

print("完成随机森林模型clf的训练")

# Predict labels for the held-out pixels. The classifier is fed a 2-D
# (n_samples, 1) feature matrix, so the 1-D test array is reshaped first.
test_features = X_test.reshape(-1, 1)
y_pred = clf.predict(test_features)

# Compute the accuracy of the predictions against the true test labels.
accuracy = accuracy_score(y_true=y_test, y_pred=y_pred)
print("准确率: ", accuracy)

# Save the trained classifier to disk. The file name is derived from
# chosen_algorithm so that a model trained on different features is not
# written under (and does not overwrite) the 'gaussian' name.
# NOTE: pickle files can execute arbitrary code when loaded; only unpickle
# model files from a trusted source.
model_path = 'clf_{}.pkl'.format(chosen_algorithm)
with open(model_path, 'wb') as model_file:
    pickle.dump(clf, model_file)

print("已保存随机森林模型clf到硬盘")
